# CelebA image generation using Conditional DCGAN
import os

import torch

from Image_Mediator_Training.Image_Mediator_Evaluation import imageMediatorEvaluation
from Image_Mediator_Training.imageMediator_graph import set_imageMediator
from ModularUtils.FunctionsConstant import load_label_dataset, asKey, initialize_results, load_image_dataset
from ModularUtils.ControllerConstants import get_multiple_labels_fill, fill2d_to_fill4d
from ModularUtils.ControllerModel import get_generators, get_discriminators, get_generated_labels
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsTraining import get_training_variables, labels_image_gradient_penalty, calc_gradient_penalty, \
    save_checkpoint, top_sort_list
from Image_Mediator_Training.MNISTVae import train_encoders


def train_CausalController(Exp, cur_mechs, label_generators, G_optimizers, img_discriminator, label_discriminator, D_optimizer1, D_optimizer2,
                           dataset_dict_batches, batchno):
    """Run one WGAN-GP step for the mechanisms in ``cur_mechs`` on batch ``batchno``.

    Critic updates happen per intervention dataset in ``dataset_dict_batches``.
    For each one, two critics are updated ``Exp.CRITIC_ITERATIONS`` times:

    - ``img_discriminator[interv_no]`` scores an image together with the first
      two label columns broadcast spatially.
    - ``label_discriminator[interv_no]`` scores the low-dimensional labels
      concatenated with the representation batch.

    The generator loss is accumulated over all interventions. A single generator
    update is then applied to every mechanism in ``cur_mechs``.

    Args:
        Exp: experiment configuration object (DEVICE, label_dim, IMAGE_SIZE, ...).
        cur_mechs: names of the mechanisms being trained.
        label_generators: generators keyed by mechanism name.
        G_optimizers: generator optimizers keyed by mechanism name.
        img_discriminator: image critics indexed by intervention number.
        label_discriminator: label critics indexed by intervention number.
        D_optimizer1: optimizers for ``img_discriminator``, same indexing.
        D_optimizer2: optimizers for ``label_discriminator``, same indexing.
        dataset_dict_batches: ``{intervention_key: {"obs": [...], "rep": [...], "img": [...]}}``,
            each value a list of pre-built batches.
        batchno: index of the batch to take from each list.

    Returns:
        ``(G_loss, D_loss, image_loss, label_loss)``.
        ``D_loss`` is a CPU scalar: the mean critic loss over every critic
        iteration of every processed intervention, or 0 if none was processed.
        ``image_loss`` and ``label_loss`` are zero placeholders.

    Raises:
        ValueError: if an intervention is processed while the mechanisms do not
            involve both an image label and a representation label. Both critics
            need the real image batch and the real label/representation fill.
    """
    G_loss = torch.zeros(1).to(Exp.DEVICE)
    # Collected across *all* interventions so the reported mean covers every one of them.
    D_losses = []
    processed_any = False

    for interv_no, (intv_key, dataset_batches) in enumerate(dataset_dict_batches.items()):
        intv_key = dict(intv_key)

        data_input = dataset_batches["obs"][batchno]

        _, _, _, graph_label_vars = get_training_variables(Exp, Exp.label_names, interv_no, intv_key)
        all_compare_Var, _, intervened_Var, real_labels_vars = get_training_variables(Exp, cur_mechs, interv_no, intv_key)

        isImage = bool(set(cur_mechs) & set(Exp.image_labels))
        isRep = bool(set(all_compare_Var) & set(Exp.rep_labels))

        # TODO: fix it later -- the batch has fewer columns than the labels to compare.
        if len(real_labels_vars) > data_input.shape[1]:
            continue

        if not (isImage and isRep):
            # Without this guard the real-side tensors below would be unbound (or stale
            # leftovers from a previous intervention).
            raise ValueError(
                "train_CausalController requires mechanisms involving both an image label and a "
                "representation label; got cur_mechs=%s" % (cur_mechs,))

        mini_batch = data_input.size(0)
        indices = [graph_label_vars.index(lb) for lb in real_labels_vars]
        current_real_label = data_input[:, indices].type(torch.LongTensor).view(-1, len(indices)).to(Exp.DEVICE)
        dims_list = [Exp.label_dim[lb] for lb in real_labels_vars]

        # Real side: label fill followed by the representation batch
        # (original ordering note: ['medD', 'medC', 'RI'], RI kept at the end).
        real_labels_fill = get_multiple_labels_fill(Exp, current_real_label, dims_list, isImage_labels=False)
        real_labels_fill = torch.cat([real_labels_fill, dataset_batches["rep"][batchno]], 1)
        # Broadcast the first two label columns to IMAGE_SIZE x IMAGE_SIZE to condition the image critic.
        real_digits = real_labels_fill[:, 0:2].unsqueeze(2).unsqueeze(3).repeat(1, 1, Exp.IMAGE_SIZE, Exp.IMAGE_SIZE)
        obs_images = dataset_batches["img"][batchno]

        # Fake side: sample intervened + compared variables in topological order.
        intv_tensor_dict = {}
        isClassifier = False
        gen_labels = top_sort_list(intervened_Var + all_compare_Var, Exp.label_names)
        generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict, gen_labels, mini_batch, hard=True)
        generated_image = generated_labels_dict.pop(Exp.image_labels[0])

        y_dims = sum(Exp.label_dim[lb] for lb in real_labels_vars + Exp.rep_labels)
        generated_labels_fill = torch.cat(list(generated_labels_dict.values()), 1).view(-1, y_dims)
        gen_digits = fill2d_to_fill4d(Exp, generated_labels_fill[:, 0:2], more_dimsize=Exp.IMAGE_SIZE)

        img_critic = img_discriminator[interv_no]
        label_critic = label_discriminator[interv_no]

        for _ in range(Exp.CRITIC_ITERATIONS):
            # Image critic: matching P(I|D) vs Q(I|D).
            D_real_img = img_critic(obs_images, real_digits).squeeze()
            D_fake_img = img_critic(generated_image, gen_digits).squeeze()
            gp_img = labels_image_gradient_penalty(img_critic, obs_images, real_digits, generated_image, gen_digits,
                                                   isClassifier, device=Exp.DEVICE)
            D_loss_img = -(torch.mean(D_real_img) - torch.mean(D_fake_img)) + Exp.LAMBDA_GP * gp_img

            # Label critic: joint P(D, RI, C).
            D_real_obs = label_critic(real_labels_fill).squeeze()
            D_fake_obs = label_critic(generated_labels_fill).squeeze()
            gp_obs = calc_gradient_penalty(label_critic, real_labels_fill, generated_labels_fill, device=Exp.DEVICE)
            D_loss_obs = -(torch.mean(D_real_obs) - torch.mean(D_fake_obs)) + Exp.LAMBDA_GP * gp_obs

            D_totalLoss = D_loss_img + D_loss_obs
            D_losses.append(D_loss_obs.data + D_loss_img.data)  # bookkeeping only

            img_critic.zero_grad()
            label_critic.zero_grad()
            # The generated tensors are not detached, so the generator part of the graph must
            # survive for later critic iterations and for G_loss below.
            D_totalLoss.backward(retain_graph=True)
            D_optimizer1[interv_no].step()
            D_optimizer2[interv_no].step()

        # Generator loss against the freshly updated critics, accumulated over interventions.
        D_fake_img = img_critic(generated_image, gen_digits).squeeze()
        D_fake_obs = label_critic(generated_labels_fill).squeeze()
        G_loss += -torch.mean(D_fake_img) - torch.mean(D_fake_obs)
        processed_any = True

    image_loss, label_loss = torch.zeros(1).to(Exp.DEVICE), torch.zeros(1).to(Exp.DEVICE)

    if not processed_any:
        # Nothing was trained: G_loss has no graph, so backward() would raise.
        return G_loss.data, torch.zeros(()), image_loss, label_loss

    for mech in cur_mechs:
        label_generators[mech].zero_grad()

    G_loss.backward()

    for mech in cur_mechs:
        G_optimizers[mech].step()

    # Convert each (possibly device-resident) loss to a Python float before averaging on CPU.
    D_loss = torch.tensor([float(d.mean()) for d in D_losses]).mean()

    return G_loss.data, D_loss.data, image_loss, label_loss




def _to_batches(dataset, batch_size, as_2d):
    """Split ``dataset`` into consecutive, unshuffled, squeezed batches.

    When ``as_2d`` is true, a batch that squeezes down to 1-D is reshaped to
    ``(-1, 1)`` so label/representation batches always keep a column axis.
    """
    batches = []
    for batch in torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False):
        batch = torch.squeeze(batch)
        if as_2d and batch.dim() == 1:
            batch = batch.view(-1, 1)
        batches.append(batch)
    return batches


def labelMain(Exp, cur_hnodes, label_generators, G_optimizers, discriminators, D_optimizers, dataset_dict,
              tvd_diff, kl_diff):
    """Train every hnode group for one epoch, then evaluate and periodically checkpoint.

    Args:
        Exp: experiment configuration; ``Exp.curr_epoochs`` is the current epoch index.
        cur_hnodes: ``{hnode_name: [mechanism names]}`` groups trained one after another.
        label_generators, G_optimizers: generators/optimizers keyed by mechanism name.
        discriminators, D_optimizers: dicts with keys ``'H1'`` (image critic) and ``'H2'``
            (label critic), forwarded to ``train_CausalController``.
        dataset_dict: ``{intervention_key: {"obs": ..., "rep": ..., "img": ...}}`` datasets.
        tvd_diff, kl_diff: running evaluation results passed to ``imageMediatorEvaluation``.

    Returns:
        ``(tvd_diff, kl_diff)`` after this epoch's evaluation, so callers can carry them forward.
    """
    # Pre-split every dataset into fixed batches; "rep" and "img" only exist when the
    # experiment has representation / image labels.
    dataset_dict_batches = {}
    num_batches = 0
    for key, each_dataset in dataset_dict.items():
        batches = {"obs": _to_batches(each_dataset["obs"], Exp.batch_size, as_2d=True)}
        num_batches = len(batches["obs"])

        if len(Exp.rep_labels):
            batches["rep"] = _to_batches(each_dataset["rep"], Exp.batch_size, as_2d=True)
            num_batches = len(batches["rep"])

        if len(Exp.image_labels):
            batches["img"] = _to_batches(each_dataset["img"], Exp.batch_size, as_2d=False)

        dataset_dict_batches[key] = batches

    for batchno in range(num_batches):

        for hn, cur_mechs in cur_hnodes.items():
            g_loss, d_loss, image_loss, label_loss = train_CausalController(
                Exp, cur_mechs, label_generators, G_optimizers, discriminators['H1'], discriminators['H2'],
                D_optimizers['H1'], D_optimizers['H2'], dataset_dict_batches, batchno)

            print('Epoch [%d/%d], Step [%d/%d],' % (
                Exp.curr_epoochs + 1, Exp.num_epochs, batchno + 1, num_batches),
                'mechanism: ', cur_mechs, ' D_loss: %.4f, G_loss: %.4f' % (d_loss.data, g_loss.data))

        # Annealing: call Exp.anneal_temperature every 100 global iterations.
        tot_iter = Exp.curr_epoochs * num_batches + batchno
        if tot_iter % 100 == 0:
            Exp.anneal_temperature(tot_iter)

        # Only the losses of the last hnode group trained on this batch are recorded.
        Exp.D_avg_losses.append(torch.mean(d_loss))
        Exp.G_avg_losses.append(torch.mean(g_loss))

    # Evaluation frequency knob: a modulus of 1 evaluates after every epoch.
    if (Exp.curr_epoochs + 1) % 1 == 0:
        print("Turn on caffeinate or these results are gone!")
        tvd_diff, kl_diff = imageMediatorEvaluation(Exp, cur_hnodes, label_generators, dataset_dict, tvd_diff, kl_diff)

    if (Exp.curr_epoochs + 1) % 5 == 0 and cur_hnodes:
        # Checkpoint under the last hnode group's mechanisms. Taking them from cur_hnodes
        # avoids depending on the training loop variable (unbound when there are no batches).
        save_mechs = list(cur_hnodes.values())[-1]
        var_list = "".join(save_mechs)
        save_checkpoint(Exp, Exp.SAVED_PATH, save_mechs, label_generators, G_optimizers, {var_list: discriminators}, {var_list: D_optimizers})
        print(Exp.curr_epoochs, ":model saved at ", Exp.SAVED_PATH)

    return tvd_diff, kl_diff




if __name__ == "__main__":

    Exp = Experiment("Exp1", set_imageMediator,
                     Temperature=1,
                     temp_min=0.1,
                     G_hid_dims=[256, 256],
                     D_hid_dims=[256, 256],
                     IMAGE_FILTERS=[128, 64, 32],
                     CRITIC_ITERATIONS=1,
                     LAMBDA_GP=10,
                     learning_rate=5 * 1e-4,
                     Synthetic_Sample_Size=20000,
                     intv_Sample_Size=20000,
                     batch_size=200,
                     features=["feature"],
                     noise_states=64,
                     latent_state=4,
                     ENCODED_DIM=10,
                     Data_intervs=[{}],
                     num_epochs=300,
                     new_experiment=True
                     )

    print(Exp.Data_intervs)
    Exp.intv_batch_size = Exp.batch_size

    os.makedirs(Exp.SAVED_PATH, exist_ok=True)

    # Previously trained models to load (placeholder path; point it at a real run).
    Exp.LOAD_MODEL_PATH = "/path_to_project/SAVED_EXPERIMENTS/imageMediator/Exp1/May_10_2023-07_39run_while"
    Exp.load_which_models = {"medD": True, "I": True, "RI": True, "medC": True}

    Exp.new_experiment = True
    train_enc = False

    comment = ""
    if comment:
        Exp.SAVED_PATH += comment

    # Critic layout: labelMain feeds discriminators['H1'] as the image critic and
    # discriminators['H2'] as the label critic.
    disc_hnodes = {"H1": ["I"], "H2": ["medD", "medC"]}

    # Changed I to RI in the joint distribution
    Exp.train_mech_dict["medD"] = [{'parents': [], 'intv': {}, 'compare': ['medD', 'RI', 'medC']}]
    Exp.train_mech_dict["I"] = [{'parents': ['medD'], 'intv': {'medD'}, 'compare': ['I']}]
    Exp.train_mech_dict["medC"] = [{'parents': ['I'], 'intv': {}, 'compare': ['medD', 'RI', 'medC']}]

    Exp.LAMBDA_GP = 10
    label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)
    discriminatorsMech, doptimizersMech = get_discriminators(Exp, disc_hnodes, Exp.load_which_models)

    # Joint training of all mechanisms (including the encoder).
    cur_hnodes = {"H1": ["medD", "I", "medC"]}
    image_data_dict = load_image_dataset(Exp, cur_hnodes)
    dataset_dict = load_label_dataset(Exp, image_data_dict, label_generators, cur_hnodes)
    obs_key = asKey({})
    dataset_dict[obs_key]["img"] = image_data_dict[obs_key]

    # Shuffle low-dim variables, images and representations with ONE shared permutation
    # so rows stay aligned. Size it by the actual data, not the configured sample size.
    obs_data = dataset_dict[obs_key]
    randices = torch.randperm(len(obs_data['obs']))
    for part in ('obs', 'img', 'rep'):
        obs_data[part] = obs_data[part][randices]

    if train_enc:
        for rep_mech in Exp.rep_labels:
            train_encoders(Exp, rep_mech, label_generators, optimizersMech, image_data_dict, dataset_dict)

    tvd_diff, kl_diff = initialize_results(Exp, cur_hnodes)
    print("Starting training new mechanism")

    for epoch in range(Exp.num_epochs):
        Exp.curr_epoochs = epoch
        results = labelMain(Exp, cur_hnodes, label_generators, optimizersMech, discriminatorsMech, doptimizersMech,
                            dataset_dict, tvd_diff, kl_diff)
        # Carry the evaluation results into the next epoch when labelMain returns them.
        if results is not None:
            tvd_diff, kl_diff = results
